import numpy as np
import random
from vllm import LLM, SamplingParams
import gc
import math
import csv
import fire
import torch
import torch.distributed
import os
from typing import List
from tqdm import tqdm
import json
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import LlamaConfig, LlamaTokenizer, LlamaForCausalLM

from utils.dataset_utils import get_dataset
from utils.icl_utils import get_icl_examples
from utils.prompt_utils import apply_prompt_template, apply_icl_prompt

from utils.model_utils import (
    setup,
    setup_environ_flags,
    clear_gpu_cache,
    load_peft_model,
)

import json
import re
import time
from datetime import datetime

def print_start_time():
    """Print the current local time as the program start time and return it.

    Returns:
        datetime: the (naive, local) timestamp that was printed, meant to be
        passed to ``print_end_time`` later.
    """
    now = datetime.now()
    stamp = now.strftime('%Y-%m-%d %H:%M:%S')
    print(f"Program start time: {stamp}")
    return now

def print_end_time(start_time):
    """Print the current local time as the end time and the elapsed runtime.

    Args:
        start_time: the ``datetime`` returned by ``print_start_time``; the
            printed runtime is ``now - start_time``.
    """
    finished = datetime.now()
    elapsed = finished - start_time
    print(f"Program end time: {finished.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Total runtime: {elapsed}")

def save(output_file, out):
    """Write every item of ``out`` to ``output_file`` as JSON Lines.

    The file is truncated first; each item is serialized with default
    ``json.dumps`` settings and followed by a single newline.
    """
    with open(output_file, 'w') as handle:
        handle.writelines(json.dumps(record) + "\n" for record in out)

def combine_results(input_files: List[str], output_file: str):
    """Merge JSON Lines shards into one file, ordered by each record's ``idx``.

    Every record must be a dict carrying an ``idx`` key. It is used for
    ordering and then removed before the merged list is written with ``save``.

    Args:
        input_files: paths of the JSON Lines shards to read.
        output_file: destination path for the merged JSON Lines file.
    """
    merged = []
    for path in input_files:
        with open(path, 'r') as shard:
            merged.extend(json.loads(line) for line in shard)
    merged.sort(key=lambda record: record['idx'])
    for record in merged:
        record.pop('idx')
    save(output_file, merged)
def question_read_csv(text_file):
    """Return the first column of every non-empty row of a comma-separated file.

    Args:
        text_file: path to the CSV file.

    Returns:
        list[str]: the first field of each non-empty row, in file order.
        Blank lines (which ``csv.reader`` yields as ``[]``) are skipped
        instead of raising ``IndexError``.
    """
    # newline='' is required by the csv module so that quoted fields with
    # embedded line breaks are parsed correctly; the with-block guarantees
    # the file is closed even if parsing raises.
    with open(text_file, "r", newline="") as file:
        return [row[0] for row in csv.reader(file, delimiter=",") if row]
def question_read_json(text_file, prompt_key):
    """Load a JSON file and extract one prompt per top-level element.

    The loaded top-level value is iterated (presumably a list). Dict elements
    contribute ``element[prompt_key]``; any other element is returned as-is.
    """
    with open(text_file, 'r') as handle:
        records = json.load(handle)
    prompts = []
    for record in records:
        if isinstance(record, dict):
            prompts.append(record[prompt_key])
        else:
            prompts.append(record)
    return prompts
def question_read_txt(text_file):
    """Return every non-blank line of a text file, stripped of surrounding whitespace."""
    with open(text_file, 'r') as handle:
        stripped = (raw.strip() for raw in handle)
        return [text for text in stripped if text]
def _sample_permutations(k):
    """Draw distinct random orderings of ``range(k)`` from NumPy's global RNG.

    Draws 10 orderings when ``k > 3``, otherwise all ``k!`` of them (in random
    order). The caller seeds ``np.random`` beforehand so the draw is
    reproducible.

    Returns:
        list[list[int]]: the distinct orderings, in the order they were drawn.
    """
    per_total = 10 if k > 3 else math.factorial(k)
    permutations = []
    while len(permutations) < per_total:
        perm = np.random.permutation(k).tolist()
        if perm not in permutations:
            permutations.append(perm)
    return permutations


def _write_results(output, emb, results, world_size, rank):
    """Persist the accumulated ``results`` as ``{output}/{emb}.jsonl``.

    Single process: writes the file directly. Distributed: every rank writes
    ``{output}/{emb}.part{rank}.jsonl`` and then waits on a barrier, and rank 0
    merges all parts (ordered by ``idx``) into the final file. All ranks must
    call this at the same points because of the barrier.
    """
    final_file = f"{output}/{emb}.jsonl"
    if world_size > 1:
        save(output_file=f"{output}/{emb}.part{rank}.jsonl", out=results)
        torch.distributed.barrier()
        if rank == 0:
            combine_results(
                input_files=[f"{output}/{emb}.part{i}.jsonl" for i in range(world_size)],
                output_file=final_file,
            )
    else:
        save(output_file=final_file, out=results)


def main(
    model_path,
    train_dataset: str="gsm8k",
    test_dataset: str="gsm8k",
    data_start: int=0,
    data_end: int=-1,
    peft_model: str=None,
    quantization: bool=False,
    max_new_tokens = 256,
    prompt_template_style: str='gsm8k',
    seed: int=42,
    do_sample: bool=True,
    use_cache: bool=True,
    top_p: float=1.0,
    temperature: float=1.0,
    top_k: int=0,
    repetition_penalty: float=1.0,
    length_penalty: int=1,
    use_fast_kernels: bool = False,
    prompt_key: str = 'instruction',
    output: str = None,
    exp_num: int=10,
    k: int=4,
    subset_size: int=100,
    method: str = "diversity",
    dp_choice: str = "knn",
    metric: str = "cosine_similarity",
    emb: str = None,
    freq: int = 0,
    if_qwa: bool = False,
    permutation: int = 1,
    apply_chat_template: bool = False,
    **kwargs
):
    """Run in-context-learning inference with vLLM and write JSONL results.

    ICL demonstrations are chosen by ``get_icl_examples``. For each sampled
    ordering of them, the function builds prompts for every test example,
    generates completions, and writes ``{"prompt", "answer"}`` records to
    ``results/[qwa/]{method}/{model_name}/{test_dataset}/{train_dataset}/{k}/{permutation_index}/{seed}/{emb}.jsonl``.
    When ``WORLD_SIZE`` is set in the environment, test examples are sharded
    round-robin across ranks and rank 0 merges the shards.

    Arguments that are currently overridden or not read here:
      * ``max_new_tokens`` is always replaced by 1024.
      * ``seed`` is replaced by ``exp_num``.
      * ``output`` is replaced by the path derived above.
      * Decoding always uses temperature=1.0, top_p=1.0, top_k=1 (effectively
        greedy). ``temperature``/``top_p``/``top_k``/``do_sample`` are ignored.
      * ``quantization``, ``peft_model``, ``use_cache``, ``length_penalty``,
        ``prompt_key``, ``data_start``, ``data_end`` and ``kwargs`` are unused.
    """
    # All dataset combinations currently share the same generation budget.
    max_new_tokens = 1024
    print(f"model_path: {model_path}")
    print(f"if_qwa: {if_qwa}")
    model_name = os.path.basename(model_path)

    start_time = print_start_time()
    # The experiment index doubles as the seed; the ``seed`` argument is ignored.
    seed = exp_num

    torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)

    world_size = os.environ.get('WORLD_SIZE')
    if world_size is None:
        world_size = 1
        local_rank = 0
        rank = 0
    else:
        world_size = int(world_size)
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ['RANK'])
        print(f"rank: {rank} local rank: {local_rank} world size: {world_size}")

        torch.distributed.init_process_group(backend='nccl', init_method='env://')

        setup()
        torch.cuda.set_device(local_rank)
        clear_gpu_cache(local_rank)
        setup_environ_flags(rank)

    model = LLM(model=model_path, dtype=torch.bfloat16)
    sampling_params = SamplingParams(
        temperature=1.0,
        top_p=1.0,
        top_k=1,
        max_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
    )

    if use_fast_kernels:
        # BetterTransformer transforms Hugging Face models; ``model`` is a vLLM
        # engine object, so the transform cannot be applied to it.
        print("use_fast_kernels is ignored: BetterTransformer cannot be applied to a vLLM LLM engine.")

    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
    # Kept as before; presumably relied on by the prompt helpers below.
    tokenizer.pad_token = tokenizer.eos_token

    train_inputs, train_outputs, _, _ = get_dataset(dataset=train_dataset, load_from_local=True)
    _, _, test_inputs, _ = get_dataset(dataset=test_dataset, load_from_local=True)
    _, train_inputs_prompts = apply_prompt_template(prompt_template_style, train_inputs, tokenizer, return_dialogs=True)
    _, test_input_prompts = apply_prompt_template(prompt_template_style, test_inputs, tokenizer, return_dialogs=True)

    idx_mat_origin = get_icl_examples(train_dataset=train_dataset, test_dataset=test_dataset, emb=emb, shuffle_seed=seed,k=k, method=method,dp_choice=dp_choice,subset_size=subset_size,metric=metric,if_qwa=if_qwa,model_name=model_name,permutation=permutation)

    np.random.seed(seed)
    random.seed(seed)
    permutations = _sample_permutations(k)

    # Only similarity-ranked selections ("knn", "knn_diversity", "k_means") are
    # evaluated under every sampled ordering; "knn_diversity" contains "knn".
    if "knn" in method or "k_means" in method:
        permutation_num = len(permutations)
    else:
        permutation_num = 1

    for permutation_index in range(permutation_num):
        program_info = f"train_dataset: {train_dataset}, test_dataset: {test_dataset}, model_name: {model_name}, emb: {emb}, method: {method}, dp_choice: {dp_choice}, metric: {metric}, exp_num: {exp_num}, subset_size: {subset_size}, k: {k}, permutation: {permutation_index}, seed: {seed}, if_qwa: {if_qwa}"
        print(f"program info: {program_info}")

        if k > 0 and method != "diversity" and method != "random":
            # Reorder the demonstration columns of every test row.
            idx_mat = idx_mat_origin[:, np.array(permutations[permutation_index])]
        else:
            idx_mat = idx_mat_origin

        result_root = "results/qwa" if if_qwa else "results"
        output = f"{result_root}/{method}/{model_name}/{test_dataset}/{train_dataset}/{k}/{permutation_index}/{seed}"
        # exist_ok avoids a check-then-create race when several ranks start together.
        os.makedirs(output, exist_ok=True)

        result_file = f"{output}/{emb}.jsonl"
        # Remove stale results from a previous run. EAFP, so ranks racing on
        # the same file cannot crash with FileNotFoundError.
        try:
            os.remove(result_file)
            print(f"{result_file} has been deleted.")
        except FileNotFoundError:
            print(f"{result_file} not exists, no need to delete.")

        dialogs = apply_icl_prompt(test_input_prompts, train_inputs_prompts, train_outputs, idx_mat, k, model_name, test_dataset, apply_chat_template)

        if apply_chat_template:
            dialogs = [
                tokenizer.apply_chat_template([{"role":"user","content":d}], add_generation_prompt=True, tokenize = False)
                for d in dialogs
            ]

        question_dataset = test_inputs

        results = []
        batch_dialogs, batch_idx = [], []
        last_idx = len(dialogs) - 1
        for idx, dialog in tqdm(enumerate(dialogs), total=len(dialogs)):
            if idx % world_size == rank:
                batch_dialogs.append(dialog)
                batch_idx.append(idx)

            # Flush every ``freq`` examples (if enabled) and always at the end.
            # The condition is rank-independent, so every rank reaches the
            # barrier inside _write_results at the same iterations.
            if (freq > 0 and idx % freq == 0) or idx == last_idx:
                # A rank can have an empty batch (e.g. ranks != 0 at idx 0);
                # skip generation but still take part in the save/barrier.
                if batch_dialogs:
                    with torch.no_grad():
                        outputs = model.generate(batch_dialogs, sampling_params=sampling_params)
                    output_text = [completion.text for out in outputs for completion in out.outputs]

                    torch.cuda.empty_cache()
                    for i, text in zip(batch_idx, output_text):
                        cur = {'prompt': question_dataset[i], 'answer': text}
                        if world_size > 1:
                            cur['idx'] = i
                        results.append(cur)

                batch_dialogs, batch_idx = [], []
                _write_results(output, emb, results, world_size, rank)

    # Drop the vLLM engine so its GPU memory can be reclaimed before exit.
    del model.llm_engine.model_executor
    del model
    gc.collect()
    print_end_time(start_time)

if __name__ == "__main__":
    # Expose main() as a command-line interface via python-fire (its parameters become CLI flags).
    fire.Fire(main)